{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    },
    "colab": {
      "name": "generate_cfg_patterns.ipynb",
      "provenance": [],
      "collapsed_sections": []
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "TMAuBkINOJRn"
      },
      "source": [
        "import os\n",
        "import re\n",
        "import json\n",
        "import pickle\n",
        "import random\n",
        "from template_config import *\n",
        "from collections import defaultdict\n",
        "from nltk.stem.porter import PorterStemmer\n",
        "from nltk.stem.wordnet import WordNetLemmatizer"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ikglVOlgOJRp"
      },
      "source": [
        "import nltk\n",
        "nltk.download('wordnet')"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ummA7CRTOJRq"
      },
      "source": [
        "ps = PorterStemmer()\n",
        "lmtzr = WordNetLemmatizer()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nfovr_-UOJRr"
      },
      "source": [
        "def read_in_all_data(data_path=DATA_PATH):\n",
        "    training_data = json.load(open(os.path.join(data_path, \"train_spider.json\")))\n",
        "    tables_org = json.load(open(os.path.join(data_path, \"tables.json\")))\n",
        "    tables = {tab['db_id']: tab for tab in tables_org}\n",
        "\n",
        "    return training_data, tables"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8C2NDusXOJRr"
      },
      "source": [
        "def get_all_question_query_pairs(data):\n",
        "    question_query_pairs = []\n",
        "    for item in data:\n",
        "        question_query_pairs.append((item['question_toks'], item['query'], item['db_id']))\n",
        "    return question_query_pairs"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tsSh3SnnOJRs"
      },
      "source": [
        "training_data, tables = read_in_all_data(\"data\")\n",
        "\n",
        "train_qq_pairs = get_all_question_query_pairs(training_data)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JmIDsISeOJRt",
        "outputId": "73bfcb8a-d503-4b0b-a399-262065eab489"
      },
      "source": [
        "print(\"Training question-query pair count: {}\".format(len(train_qq_pairs)))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Training question-query pair count: 7000\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s1mhvelTOJRu"
      },
      "source": [
        "def is_value(token):\n",
        "    \"\"\"\n",
        "    as values can either be a numerical digit or a string literal, then we can\n",
        "    detect if a token is a value by matching with regex\n",
        "    \"\"\"\n",
        "    is_number = True\n",
        "    try:\n",
        "        float(token)\n",
        "    except ValueError:\n",
        "        is_number = False\n",
        "    is_string = token.startswith(\"\\\"\") or token.startswith(\"\\'\") or token.endswith(\"\\\"\") or token.endswith(\"\\'\")\n",
        "\n",
        "    return is_number or is_string\n",
        "\n",
        "\n",
        "def remove_all_from_clauses(query_keywords):\n",
        "    \"\"\"\n",
        "    remove all keywords from from clauses, until there is no more from clauses\n",
        "    e.g. select {} from {} as {} where {} = {} --> select {} where {} = {}\n",
        "    \"\"\"\n",
        "    # remove from clause by deleting the range from \"FROM\" to \"WHERE\" or \"GROUP\"\n",
        "    start_location = 0\n",
        "    count = 0\n",
        "    while \"FROM\" in query_keywords:\n",
        "        count += 1\n",
        "        if count > 5:\n",
        "            break\n",
        "            print(\"error query_keywords: \", query_keywords)\n",
        "        from_location = query_keywords.index(\"FROM\")\n",
        "        end_token_locations = [len(query_keywords)]  # defaulting to the end of the list\n",
        "        for end_token in [\"WHERE\", \"GROUP\", \"ORDER\"]:\n",
        "            try:\n",
        "                end_token_locations.append(query_keywords.index(end_token, start_location))\n",
        "            except ValueError:\n",
        "                pass\n",
        "\n",
        "        query_keywords = query_keywords[:from_location] + [FROM_SYMBOL] + query_keywords[min(end_token_locations):]\n",
        "        start_location = min(end_token_locations)\n",
        "        \n",
        "    return query_keywords\n",
        "\n",
        "\n",
        "def strip_query(query, table):\n",
        "    \"\"\"\n",
        "    returns (stripped query, non keywords)\n",
        "    \"\"\"\n",
        "    #get table column names info\n",
        "    column_types = table['column_types']\n",
        "    table_names_original = [cn.lower() for cn in table['table_names_original']]\n",
        "    table_names = [cn.lower() for cn in table['table_names']]\n",
        "    column_names = [cn.lower() for i, cn in table['column_names']]\n",
        "    column_names_original = [cn.lower() for i, cn in table['column_names_original']]\n",
        "\n",
        "    #clean query: replace values, numbers, column names with SYMBOL\n",
        "    query_keywords = []\n",
        "    columns = table_names_original + table_names\n",
        "\n",
        "    query = query.replace(\";\",\"\")\n",
        "    query = query.replace(\"\\t\",\"\")\n",
        "    query = query.replace(\"(\", \" ( \").replace(\")\", \" ) \")\n",
        "    # then replace all stuff enclosed by \"\" with a numerical value to get it marked as {VALUE}\n",
        "    str_1 = re.findall(\"\\\"[^\\\"]*\\\"\", query)\n",
        "    str_2 = re.findall(\"\\'[^\\']*\\'\", query)\n",
        "    values = str_1 + str_2\n",
        "    for val in values:\n",
        "        query = query.replace(val.strip(), VALUE_STR_SYMBOL)\n",
        "\n",
        "    query_tokenized = query.split(' ')\n",
        "    float_nums = re.findall(\"[-+]?\\d*\\.\\d+\", query)\n",
        "    query_tokenized = [VALUE_NUM_SYMBOL if qt in float_nums else qt for qt in query_tokenized]\n",
        "    query = \" \".join(query_tokenized)\n",
        "    int_nums = [i.strip() for i in re.findall(\"[^tT]\\d+\", query)]\n",
        "    query_tokenized = [VALUE_NUM_SYMBOL if qt in int_nums else qt for qt in query_tokenized]\n",
        "    nums = float_nums + int_nums\n",
        "        \n",
        "    #query_tokenized = query.split(' ')\n",
        "    cols_dict = {}\n",
        "    for token in query_tokenized:\n",
        "        if len(token.strip()) == 0:  # in case there are more than one space used\n",
        "            continue\n",
        "        if IGNORE_COMMAS_AND_ROUND_BRACKETS:\n",
        "            keywords_dict = SQL_KEYWORDS_AND_OPERATORS_WITHOUT_COMMAS_AND_BRACES\n",
        "        else:\n",
        "            keywords_dict = SQL_KEYWORDS_AND_OPERATORS\n",
        "\n",
        "        if token.upper() not in keywords_dict and token != VALUE_STR_SYMBOL and token != VALUE_NUM_SYMBOL:\n",
        "            token = token.upper()\n",
        "            if USE_COLUMN_AND_VALUE_REPLACEMENT_TOKEN:\n",
        "                token = re.sub(\"[T]\\d+\\.\", '', token)\n",
        "                token = re.sub(r\"\\\"|\\'\", '', token)\n",
        "                token = re.sub(\"[T]\\d+\", '', token).lower()\n",
        "#                 if token in table_names_original:\n",
        "#                     query_keywords.append(TABLE_SYMBOL)\n",
        "#                     continue\n",
        "                if token != '' and token in column_names_original:\n",
        "                    try:\n",
        "                        tok_ind = column_names_original.index(token)\n",
        "                    except:\n",
        "                        print(\"\\ntable: {}\".format(table['db_id']))\n",
        "                        print(\"\\ntoken: {}\".format(token))\n",
        "                        print(\"column_names_original: {}\".format(column_names_original))\n",
        "                        print(\"query: {}\".format(query))\n",
        "                        print(\"query_tokenized: {}\".format(query_tokenized))\n",
        "                        exit(1)\n",
        "                    col_type = column_types[tok_ind]\n",
        "                    col_name = column_names[tok_ind]\n",
        "                    columns.append(col_name)\n",
        "                    columns.append(token)\n",
        "                    if token not in cols_dict:\n",
        "                        cols_dict[token] = COLUMN_SYMBOL.replace(\"}\", str(len(cols_dict)))\n",
        "                    query_keywords.append(cols_dict[token])\n",
        "                elif token in table_names_original:\n",
        "                    query_keywords.append(TABLE_SYMBOL)\n",
        "                    continue\n",
        "                    \n",
        "        else:\n",
        "            query_keywords.append(token.upper())\n",
        "\n",
        "    if \"FROM\" in query_keywords:\n",
        "        query_keywords = remove_all_from_clauses(query_keywords)\n",
        "\n",
        "    if USE_LIMITED_KEYWORD_SET:\n",
        "        query_keywords = [kw for kw in query_keywords if kw in LIMITED_KEYWORD_SET]\n",
        "\n",
        "    columns_lemed = [lmtzr.lemmatize(w) for w in \" \".join(columns).split(\" \") if w not in LOW_CHAR]\n",
        "    columns_lemed_stemed = [ps.stem(w) for w in columns_lemed]\n",
        "\n",
        "    return \" \".join(query_keywords), values, nums, columns_lemed_stemed\n",
        "\n",
        "\n",
        "def filter_string(cs):\n",
        "    return \"\".join([c.upper() for c in cs if c.isalpha() or c == ' '])\n",
        "\n",
        "\n",
        "def process_question(question, values, nums, columns):\n",
        "\n",
        "    question = \" \".join(question).lower()\n",
        "    values = [re.sub(r\"\\\"|\\'\", '', val) for val in values]\n",
        "    for val in values:\n",
        "        val = val.lower()\n",
        "        try:\n",
        "            question = re.sub(r'\\b'+val+r'\\b', VALUE_STR_SYMBOL, question)\n",
        "        except:\n",
        "            continue\n",
        "\n",
        "    for num in nums:\n",
        "        num = num.strip()\n",
        "        question = re.sub(r'\\b'+num+r'\\b', VALUE_NUM_SYMBOL, question)\n",
        "\n",
        "    question_toks = question.split(\" \")\n",
        "    question_lemed = [lmtzr.lemmatize(w) for w in question_toks]\n",
        "    question_lemed_stemed = [ps.stem(w) for w in question_lemed]\n",
        "    replace_inds = [i for i, qt in enumerate(question_lemed_stemed) if qt in columns]\n",
        "    #print(\"question_stemed: {}\".format(question_stemed))\n",
        "    #print(\"replace_inds: {}\".format(replace_inds))\n",
        "    for ind in replace_inds:\n",
        "        question_toks[ind] = COLUMN_SYMBOL\n",
        "\n",
        "    question_template = ' '.join(question_toks)\n",
        "\n",
        "    return question_template"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZCE1SVHuOJRx"
      },
      "source": [
        "KEY_KEYWORD_SET = {\"SELECT\", \"WHERE\", \"GROUP\", \"HAVING\", \"ORDER\", \"BY\", \"LIMIT\", \"EXCEPT\", \"UNION\", \"INTERSECT\"}\n",
        "ALL_KEYWORD_SET = {\"SELECT\", \"WHERE\", \"GROUP\", \"HAVING\", \"DESC\", \"ORDER\", \"BY\", \"LIMIT\", \"EXCEPT\", \"UNION\", \n",
        "                   \"INTERSECT\", \"NOT\", \"IN\", \"OR\", \"LIKE\", \"(\", \">\", \")\", \"COUNT\"}\n",
        "\n",
        "WHERE_OPS = ['=', '>', '<', '>=', '<=', '!=', 'LIKE', 'IS', 'EXISTS']\n",
        "AGG_OPS = ['MAX', 'MIN', 'SUM', 'AVG']\n",
        "DASC = ['ASC', 'DESC']\n",
        "def general_pattern(pattern):\n",
        "    general_pattern_list = []\n",
        "    for x in pattern.split(\" \"):\n",
        "        if x in KEY_KEYWORD_SET:\n",
        "            general_pattern_list.append(x)\n",
        "    \n",
        "    return \" \".join(general_pattern_list)\n",
        "\n",
        "def sub_pattern(pattern):\n",
        "    general_pattern_list = []\n",
        "    for x in pattern.split(\" \"):\n",
        "        if x in ALL_KEYWORD_SET:\n",
        "            general_pattern_list.append(x)\n",
        "    \n",
        "    return \" \".join(general_pattern_list)\n",
        "\n",
        "def tune_pattern(pattern):\n",
        "    general_pattern_list = []\n",
        "    cols_dict = {}\n",
        "    for x in pattern.split(\" \"):\n",
        "        if \"{COLUMN\" in x:\n",
        "            if x not in cols_dict:\n",
        "                cols_dict[x] = COLUMN_SYMBOL.replace(\"}\", str(len(cols_dict))+\"}\")\n",
        "            general_pattern_list.append(cols_dict[x])\n",
        "            continue\n",
        "            \n",
        "        if \"{VALUE\" in x:\n",
        "            general_pattern_list.append(\"{VALUE}\")\n",
        "            continue\n",
        "            \n",
        "        if x == 'DISTINCT':\n",
        "            continue\n",
        "        elif x in DASC:\n",
        "            general_pattern_list.append(\"{DASC}\")\n",
        "        elif x in WHERE_OPS:\n",
        "            general_pattern_list.append(\"{OP}\")\n",
        "        elif x in AGG_OPS:\n",
        "            general_pattern_list.append(\"{AGG}\")\n",
        "        else:\n",
        "            general_pattern_list.append(x)\n",
        "    \n",
        "    return \" \".join(general_pattern_list)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HG_oUG7eOJRy"
      },
      "source": [
        "training_question_pattern_pairs = []\n",
        "training_patterns = set()\n",
        "\n",
        "pattern_question_dict = defaultdict(list)\n",
        "\n",
        "# train_qq_pairs\n",
        "for eid, (question, query, bd_id) in enumerate(train_qq_pairs):\n",
        "    table = tables[bd_id]\n",
        "    if eid % 500 == 0:\n",
        "        print(\"processing eid: \", eid)\n",
        "    \n",
        "    pattern, values, nums, columns = strip_query(query, table)\n",
        "    question_template = process_question(question, values, nums, columns)\n",
        "    \n",
        "    gen_pattern = general_pattern(pattern)\n",
        "    more_pattern = sub_pattern(pattern)\n",
        "    tu_pattern = tune_pattern(pattern)\n",
        "    \n",
        "    pattern_question_dict[tu_pattern].append(' '.join(question) + \" ||| \" + \n",
        "                                              question_template + \" ||| \" + more_pattern\n",
        "                                              + \" ||| \" + query)\n",
        "#     print(\"\\n--------------------------------------\")\n",
        "#     print(\"original question: {}\".format(' '.join(question).encode('utf-8')))\n",
        "#     print(\"question: {}\".format(question_template.encode('utf-8')))\n",
        "#     print(\"query: {}\".format(query))\n",
        "#     print(\"pattern: {}\".format(pattern))\n",
        "#     print(\"values: {}\".format(values))\n",
        "#     print(\"nums: {}\".format(nums))\n",
        "#     print(\"columns: {}\".format(columns))"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "id": "7cz7OBJqOJRz",
        "outputId": "835ce2ed-9a58-4377-b395-15a96173a11a"
      },
      "source": [
        "print(\"total pattern number: {}\".format(len(pattern_question_dict)))\n",
        "pattern_question_dict = sorted(pattern_question_dict.items(), key=lambda kv: len(kv[1]), reverse=True)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "total pattern number: 517\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "joTm3AK4OJRz"
      },
      "source": [
        "# filter_nums = [762, 275, 241, 204, 202, 164, 98, 59, 55, 48]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "id": "jNfqLpCsOJR0"
      },
      "source": [
        "for sql, qts in pattern_question_dict:\n",
        "#     if len(qts) not in filter_nums:\n",
        "#         continue\n",
        "    print(\"\\n--------------------------------------------\")\n",
        "    print(\"SQL Pattern: {}\".format(sql))\n",
        "    print(\"count: {}\".format(len(qts)))\n",
        "    for qt in qts:\n",
        "        q, q_template, sql, sql_more = qt.split(\"|||\")\n",
        "        print(\"question: \", q.replace(\"\"\"'\"\"\", \"\").replace(\"\"\"``\"\"\", ''))\n",
        "#         print(\"question: \", q_template.replace(\"\"\"'\"\"\", \"\").replace(\"\"\"``\"\"\", ''))\n",
        "        print(\"SQL: {} \\n\".format(sql_more))\n",
        "#     for qt in qts:\n",
        "#         q, q_template, sql_temp, sql_more = qt.split(\"|||\")\n",
        "#     #     print(\"question: \", q_template)\n",
        "#     #     print(\"sql_temp: \", sql_temp)\n",
        "#     #     print(\"sql_more: \", sql_more)\n",
        "#         if sql == 'SELECT {COLUMN0} {FROM} WHERE {COLUMN4} {OP} {VALUE_STR} AND {COLUMN5} {OP} {VALUE_STR}':\n",
        "#             print(sql_more)\n"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "id": "c71M2EUVOJR2"
      },
      "source": [
        "for sql, qts in pattern_question_dict:\n",
        "    print(\"\\n\")\n",
        "    print(\"SQL Pattern: {}\".format(sql))\n",
        "    print(\"count: \", len(qts))"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "id": "6SkVDZm8OJR2"
      },
      "source": [
        "for sql_template, qts in pattern_question_dict:\n",
        "    print(\"\\n--------------------------------------\")\n",
        "    print(\"SQL Pattern: {}\".format(sql_template))\n",
        "    print(\"count: \", len(qts))\n",
        "    sql_dict = defaultdict(int)\n",
        "    for qt in qts:\n",
        "        q, q_template, sql, sql_more = qt.split(\"|||\")\n",
        "        sql_dict[sql] += 1\n",
        "        \n",
        "    sql_count = sorted(sql_dict.items(), key=lambda kv: kv[1])\n",
        "    for sql, count in sql_count:\n",
        "        print(\"SQL: {}, count: {}\".format(sql, count))"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "id": "PSt2Ca_TOJR3"
      },
      "source": [
        "for sql_template, qts in pattern_question_dict:\n",
        "    print(\"\\n--------------------------------------\")\n",
        "    print(\"SQL Pattern: {}\".format(sql_template))\n",
        "    print(\"count: \", len(qts))\n",
        "    for qt in qts:\n",
        "        q, q_template, sql, sql_more = qt.split(\"|||\")\n",
        "        print(\"question: \", q)\n",
        "        print(\"SQL: {} \\n\".format(sql_more))"
      ],
      "execution_count": 6,
      "outputs": []
    }
  ]
}